import logging
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import List
from scipy import ndimage
from PIL import Image

from cvpods.layers import ShapeSpec, get_norm, matrix_nms
from cvpods.modeling.nn_utils.weight_init import normal_init
from cvpods.structures import ImageList, Instances, BitMasks
from cvpods.utils import log_first_n
from cvpods.modeling.losses import sigmoid_focal_loss_jit, dice_loss
from .solo import SOLO, SOLOHead, multi_apply, points_nms


class DecoupledSOLO(SOLO):
    """
    Implement Decoupled SOLO.
    See: https://arxiv.org/pdf/1512.02325.pdf for more details.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

    def _init_head(self, cfg, feature_shapes):
        assert self.head_type == 'DecoupledSOLOHead'
        self.head = DecoupledSOLOHead(cfg, feature_shapes)

    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:

                * image: Tensor, image in (C, H, W) format.
                * instances: Instances

                Other information that's included in the original dicts, such as:

                * "height", "width" (int): the output resolution of the model, used in inference.
                    See :meth:`postprocess` for details.
        Returns:
            In training, dict[str: Tensor]: mapping from a named loss to a
                tensor storing the loss.
            In inference, list[dict]: one dict per image, with an "instances"
                key holding the predicted :class:`Instances`.
        """
        images = self.preprocess_image(batched_inputs)
        if "instances" in batched_inputs[0]:
            gt_instances = [
                x["instances"].to(self.device) for x in batched_inputs
            ]
        elif "targets" in batched_inputs[0]:
            log_first_n(
                logging.WARN,
                "'targets' in the model inputs is now renamed to 'instances'!",
                n=10)
            gt_instances = [
                x["targets"].to(self.device) for x in batched_inputs
            ]
        else:
            gt_instances = None

        features = self.backbone(images.tensor)
        features = [features[f] for f in self.in_features]

        if self.training:
            ins_preds_x, ins_preds_y, cate_preds = self.head(features, eval=False)
            featmap_sizes = [featmap.size()[-2:] for featmap in ins_preds_x]
            ins_label_list, cate_label_list, ins_ind_label_list, ins_ind_label_list_xy = \
                self.get_ground_truth(gt_instances, featmap_sizes)
            return self.losses(ins_preds_x, ins_preds_y, cate_preds, ins_label_list,
                               cate_label_list, ins_ind_label_list, ins_ind_label_list_xy)
        else:
            ins_preds_x, ins_preds_y, cate_preds = self.head(features, eval=True)
            results = self.inference(ins_preds_x, ins_preds_y, cate_preds, batched_inputs)
            processed_results = [{"instances": r} for r in results]
            return processed_results

    def losses(self,
               ins_preds_x,
               ins_preds_y,
               cate_preds,
               ins_label_list,
               cate_label_list,
               ins_ind_label_list,
               ins_ind_label_list_xy):
        # gather GT instance masks for positive grid cells, per level
        ins_labels = []
        for ins_labels_level, ins_ind_labels_level in \
                zip(zip(*ins_label_list), zip(*ins_ind_label_list)):
            ins_labels_per_level = []
            for ins_labels_level_img, ins_ind_labels_level_img in \
                    zip(ins_labels_level, ins_ind_labels_level):
                ins_labels_per_level.append(
                    ins_labels_level_img[ins_ind_labels_level_img, ...]
                )
            ins_labels.append(torch.cat(ins_labels_per_level))

        ins_preds_x_final = []
        for ins_preds_level_x, ins_ind_labels_level in \
                zip(ins_preds_x, zip(*ins_ind_label_list_xy)):
            ins_preds_x_final_per_level = []
            for ins_preds_level_img_x, ins_ind_labels_level_img in \
                    zip(ins_preds_level_x, ins_ind_labels_level):
                ins_preds_x_final_per_level.append(
                    ins_preds_level_img_x[ins_ind_labels_level_img[:, 1], ...]
                )
            ins_preds_x_final.append(torch.cat(ins_preds_x_final_per_level))

        ins_preds_y_final = []
        for ins_preds_level_y, ins_ind_labels_level in \
                zip(ins_preds_y, zip(*ins_ind_label_list_xy)):
            ins_preds_y_final_per_level = []
            for ins_preds_level_img_y, ins_ind_labels_level_img in \
                    zip(ins_preds_level_y, ins_ind_labels_level):
                ins_preds_y_final_per_level.append(
                    ins_preds_level_img_y[ins_ind_labels_level_img[:, 0], ...]
                )
            ins_preds_y_final.append(torch.cat(ins_preds_y_final_per_level))

        num_ins = 0.
        # dice loss, per level
        loss_ins = []
        for input_x, input_y, target in zip(ins_preds_x_final, ins_preds_y_final, ins_labels):
            mask_n = input_x.size(0)
            if mask_n == 0:
                continue
            num_ins += mask_n
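            # Decoupled SOLO factorizes each cell's mask into an x-branch and a
            # y-branch prediction; the soft mask is the elementwise product of
            # the two sigmoids.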
            mask_pred = input_x.sigmoid() * input_y.sigmoid()
            target = target.float() / 255.
            loss_ins.append(dice_loss(mask_pred, target))

        # guard against batches with no positive grid cells
        if len(loss_ins) == 0:
            loss_ins = ins_preds_x[0].sum() * 0.
        else:
            loss_ins = torch.cat(loss_ins).mean()
        loss_ins = loss_ins * self.loss_ins_weight

        # cate
        cate_preds = [
            cate_pred.permute(0, 2, 3, 1).reshape(-1, self.num_classes)
            for cate_pred in cate_preds
        ]
        cate_preds = torch.cat(cate_preds)

        cate_labels = []
        for cate_labels_level in zip(*cate_label_list):
            cate_labels_per_level = []
            for cate_labels_level_img in cate_labels_level:
                cate_labels_per_level.append(
                    cate_labels_level_img.flatten()
                )
            cate_labels.append(torch.cat(cate_labels_per_level))
        flatten_cate_labels = torch.cat(cate_labels)
        foreground_idxs = flatten_cate_labels != self.num_classes
        cate_labels = torch.zeros_like(cate_preds)
        cate_labels[foreground_idxs, flatten_cate_labels[foreground_idxs]] = 1

        loss_cate = self.loss_cat_weight * sigmoid_focal_loss_jit(
            cate_preds,
            cate_labels,
            alpha=self.loss_cat_alpha,
            gamma=self.loss_cat_gamma,
            reduction="sum",
        ) / max(1, num_ins)
        return dict(loss_ins=loss_ins, loss_cate=loss_cate)

    @torch.no_grad()
    def get_ground_truth(self, gt_instances, featmap_sizes):
        """
        Args:
            gt_instances (list[Instances]): a list of N `Instances`, one per
                image, holding the ground-truth per-instance annotations.
                Provided during training only.
            featmap_sizes (list[tuple]): a list of #level elements. Each is
                the (H, W) size of the mask prediction at that feature level.
        """
        ins_label_list, cate_label_list, ins_ind_label_list, ins_ind_label_list_xy = multi_apply(
            self.solo_target_single_image,
            gt_instances,
            featmap_sizes=featmap_sizes)
        return ins_label_list, cate_label_list, ins_ind_label_list, ins_ind_label_list_xy

    @torch.no_grad()
    def solo_target_single_image(self, gt_instance, featmap_sizes):
        """
        Prepare ground truth for single image.
        """
        device = self.device
        gt_bboxes_raw = gt_instance.gt_boxes
        gt_labels_raw = gt_instance.gt_classes
        gt_masks_raw = gt_instance.gt_masks

        # ins: scale measure for level assignment (sqrt of box area)
        gt_areas = torch.sqrt(gt_bboxes_raw.area())

        ins_label_list = []  # per level
        cate_label_list = []  # per level
        ins_ind_label_list = []  # per level
        ins_ind_label_list_xy = []  # per level
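        # Each instance is assigned to every level whose scale range contains
        # sqrt(box area); since the ranges overlap, a single object can be
        # matched at more than one level.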
        for (lower_bound, upper_bound), stride, featmap_size, num_grid in zip(
                self.scale_ranges, self.feature_strides, featmap_sizes, self.seg_num_grids):
            ins_label = torch.zeros([num_grid ** 2, featmap_size[0],
                                     featmap_size[1]], dtype=torch.uint8, device=device)
            cate_label = torch.full(
                [num_grid, num_grid], self.num_classes, dtype=torch.int64, device=device)
            ins_ind_label = torch.zeros([num_grid ** 2], dtype=torch.bool, device=device)

            hit_indices = ((gt_areas >= lower_bound) & (
                gt_areas <= upper_bound)).nonzero().flatten()

            if len(hit_indices) == 0:
                ins_label = torch.zeros([1, featmap_size[0], featmap_size[1]], dtype=torch.uint8,
                                        device=device)
                ins_label_list.append(ins_label)
                cate_label_list.append(cate_label)
                # default is False
                ins_ind_label = torch.zeros([1], dtype=torch.bool, device=device)
                ins_ind_label_list.append(ins_ind_label)
                ins_ind_label_list_xy.append(torch.zeros([0, 2], dtype=torch.int64, device=device))
                continue

            gt_bboxes = gt_bboxes_raw[hit_indices]
            gt_labels = gt_labels_raw[hit_indices]
            gt_masks = gt_masks_raw[hit_indices]
            # move masks to CPU and convert to ndarray to compute instance mass centers
            gt_masks = gt_masks.tensor.to("cpu").numpy()

            half_ws = 0.5 * (gt_bboxes.tensor[:, 2] - gt_bboxes.tensor[:, 0]) * self.sigma
            half_hs = 0.5 * (gt_bboxes.tensor[:, 3] - gt_bboxes.tensor[:, 1]) * self.sigma

            output_stride = stride / 2
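            # mask targets are rasterized at half the feature stride because the
            # head upsamples its mask predictions by 2x (see forward_single_level)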
            # For each mask
            for seg_mask, gt_label, half_h, half_w in zip(gt_masks, gt_labels, half_hs, half_ws):
                if seg_mask.sum() < 10:
                    continue
                # mass center
                upsampled_size = (featmap_sizes[0][0] * 4, featmap_sizes[0][1] * 4)
                center_h, center_w = ndimage.center_of_mass(seg_mask)
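                # map the normalized mass center to a grid cell, e.g. with
                # num_grid = 40 a center at 53% of the image width lands in
                # column int(0.53 * 40) = 21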
                coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid))
                coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))

                # grid-cell bounds of the sigma-shrunk center region:
                # top, down, left, right
                top_box = max(0, int(((center_h - half_h) / upsampled_size[0]) // (1. / num_grid)))
                down_box = min(num_grid - 1, int(((center_h + half_h) /
                                                  upsampled_size[0]) // (1. / num_grid)))
                left_box = max(0, int(((center_w - half_w) / upsampled_size[1]) // (1. / num_grid)))
                right_box = min(num_grid - 1, int(((center_w + half_w) /
                                                   upsampled_size[1]) // (1. / num_grid)))

                top = max(top_box, coord_h - 1)
                down = min(down_box, coord_h + 1)
                left = max(coord_w - 1, left_box)
                right = min(right_box, coord_w + 1)

                cate_label[top:(down + 1), left:(right + 1)] = gt_label
                # ins: rasterize the mask target at this level's output stride
                scale = 1. / output_stride
                h, w = seg_mask.shape[-2:]
                new_h, new_w = int(h * scale + 0.5), int(w * scale + 0.5)
                seg_mask = Image.fromarray(seg_mask)
                seg_mask = seg_mask.resize((new_w, new_h), Image.BILINEAR)
                seg_mask = np.array(seg_mask)
                seg_mask = torch.from_numpy(seg_mask)
                for i in range(top, down + 1):
                    for j in range(left, right + 1):
                        label = int(i * num_grid + j)
                        ins_label[label, :seg_mask.shape[0], :seg_mask.shape[1]] = seg_mask
                        ins_ind_label[label] = True

            # Instance mask
            ins_label = ins_label[ins_ind_label]
            ins_label_list.append(ins_label)
            # Instance category
            cate_label_list.append(cate_label)
            # Instance index
            ins_ind_label = ins_ind_label[ins_ind_label]
            ins_ind_label_list.append(ins_ind_label)
            # Instance coordinate
            foreground_idxs = (cate_label != self.num_classes)
            ins_ind_label_list_xy.append(foreground_idxs.nonzero())

        return ins_label_list, cate_label_list, ins_ind_label_list, ins_ind_label_list_xy

    @torch.no_grad()
    def inference(self, ins_preds_x, ins_preds_y, cate_preds, batched_inputs):
        assert len(ins_preds_x) == len(cate_preds)
        num_levels = len(cate_preds)
        featmap_size = ins_preds_x[0].size()[-2:]

        results = []
        for img_id, batched_input in enumerate(batched_inputs):
            cate_pred_list = []
            seg_pred_list_x = []
            seg_pred_list_y = []
            for i in range(num_levels):
                cate_pred_list.append(
                    cate_preds[i][img_id].view(-1, self.num_classes).detach()
                )
                seg_pred_list_x.append(
                    ins_preds_x[i][img_id].detach()
                )
                seg_pred_list_y.append(
                    ins_preds_y[i][img_id].detach()
                )
            cate_pred_list = torch.cat(cate_pred_list, dim=0)
            seg_pred_list_x = torch.cat(seg_pred_list_x, dim=0)
            seg_pred_list_y = torch.cat(seg_pred_list_y, dim=0)

            # use the raw (pre-padding) input size; "instances" may be absent
            # at test time
            img_shape = batched_input["image"].shape[-2:]
            ori_shape = (batched_input["height"], batched_input["width"])

            results_per_image = self.inference_single_image(
                cate_pred_list, seg_pred_list_x, seg_pred_list_y,
                featmap_size, img_shape, ori_shape)
            results.append(results_per_image)
        return results

    @torch.no_grad()
    def inference_single_image(self,
                               cate_preds,
                               seg_preds_x,
                               seg_preds_y,
                               featmap_size,
                               img_shape,
                               ori_shape):
        result = Instances(ori_shape)

        # overall info.
        h, w = img_shape
        upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4)

        # Per-cell lookup tables over the flattened multi-level grid:
        # cumulative cell counts (trans_size), per-level index offsets
        # (trans_diff, seg_diff), grid sizes and feature strides.
        trans_size = torch.Tensor(self.seg_num_grids).pow(2).cumsum(0).long()
        trans_diff = torch.ones(trans_size[-1].item(), device=self.device).long()
        num_grids = torch.ones(trans_size[-1].item(), device=self.device).long()
        seg_size = torch.Tensor(self.seg_num_grids).cumsum(0).long()
        seg_diff = torch.ones(trans_size[-1].item(), device=self.device).long()
        strides = torch.ones(trans_size[-1].item(), device=self.device)

        n_stage = len(self.seg_num_grids)
        trans_diff[:trans_size[0]] *= 0
        seg_diff[:trans_size[0]] *= 0
        num_grids[:trans_size[0]] *= self.seg_num_grids[0]
        strides[:trans_size[0]] *= self.feature_strides[0]

        for ind_ in range(1, n_stage):
            trans_diff[trans_size[ind_ - 1]:trans_size[ind_]] *= trans_size[ind_ - 1]
            seg_diff[trans_size[ind_ - 1]:trans_size[ind_]] *= seg_size[ind_ - 1]
            num_grids[trans_size[ind_ - 1]:trans_size[ind_]] *= self.seg_num_grids[ind_]
            strides[trans_size[ind_ - 1]:trans_size[ind_]] *= self.feature_strides[ind_]
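        # A worked example, assuming the SOLO-paper default grids
        # S = [40, 36, 24, 16, 12]: trans_size = cumsum(S**2) =
        # [1600, 2896, 3472, 3728, 3872] and seg_size = cumsum(S) =
        # [40, 76, 100, 116, 128]. A flattened cell index in [1600, 2896)
        # lives on level 1: subtract trans_diff = 1600, divide (and take the
        # remainder) by num_grids = 36 to recover (row, col), then add
        # seg_diff = 40 to index into the concatenated x/y mask banks.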

        # process.
        inds = (cate_preds > self.score_threshold)
        # category scores.
        cate_scores = cate_preds[inds]

        # category labels.
        inds = inds.nonzero()
        trans_diff = torch.index_select(trans_diff, dim=0, index=inds[:, 0])
        seg_diff = torch.index_select(seg_diff, dim=0, index=inds[:, 0])
        num_grids = torch.index_select(num_grids, dim=0, index=inds[:, 0])
        strides = torch.index_select(strides, dim=0, index=inds[:, 0])

        y_inds = (inds[:, 0] - trans_diff) // num_grids
        x_inds = (inds[:, 0] - trans_diff) % num_grids
        y_inds += seg_diff
        x_inds += seg_diff

        cate_labels = inds[:, 1]
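        # decoupled composition: the soft mask of grid cell (i, j) is the
        # product of the j-th x-branch map and the i-th y-branch map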
        seg_masks_soft = seg_preds_x[x_inds, ...] * seg_preds_y[y_inds, ...]
        seg_masks = seg_masks_soft > self.mask_threshold
        sum_masks = seg_masks.sum((1, 2)).float()

        # filter.
        keep = sum_masks > strides
        if keep.sum() == 0:
            return result

        seg_masks_soft = seg_masks_soft[keep, ...]
        seg_masks = seg_masks[keep, ...]
        cate_scores = cate_scores[keep]
        sum_masks = sum_masks[keep]
        cate_labels = cate_labels[keep]

        # mask scoring: rescore each candidate by its "maskness", the mean
        # soft-mask value inside its binary mask
        seg_score = (seg_masks_soft * seg_masks.float()).sum((1, 2)) / sum_masks
        cate_scores *= seg_score

        if len(cate_scores) == 0:
            return result

        # sort and keep top nms_pre
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > self.nms_per_image:
            sort_inds = sort_inds[: self.nms_per_image]
        seg_masks_soft = seg_masks_soft[sort_inds, :, :]
        seg_masks = seg_masks[sort_inds, :, :]
        cate_scores = cate_scores[sort_inds]
        sum_masks = sum_masks[sort_inds]
        cate_labels = cate_labels[sort_inds]

        # Matrix NMS
        cate_scores = matrix_nms(seg_masks, cate_labels, cate_scores,
                                 kernel=self.nms_kernel, sigma=self.nms_sigma, sum_masks=sum_masks)

        # filter.
        keep = cate_scores >= self.update_threshold
        seg_masks_soft = seg_masks_soft[keep, :, :]
        cate_scores = cate_scores[keep]
        cate_labels = cate_labels[keep]

        # sort and keep top_k
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > self.max_detections_per_image:
            sort_inds = sort_inds[: self.max_detections_per_image]
        seg_masks_soft = seg_masks_soft[sort_inds, :, :]
        cate_scores = cate_scores[sort_inds]
        cate_labels = cate_labels[sort_inds]

        seg_masks_soft = F.interpolate(seg_masks_soft.unsqueeze(0),
                                       size=upsampled_size_out,
                                       mode='bilinear')[:, :, :h, :w]
        seg_masks = F.interpolate(seg_masks_soft,
                                  size=ori_shape,
                                  mode='bilinear').squeeze(0)
        seg_masks = seg_masks > self.mask_threshold

        seg_masks = BitMasks(seg_masks)
        result.pred_masks = seg_masks
        result.pred_boxes = seg_masks.get_bounding_boxes()
        result.scores = cate_scores
        result.pred_classes = cate_labels
        return result

    def preprocess_image(self, batched_inputs):
        """
        Normalize, pad and batch the input images.
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [self.normalizer(x) for x in images]
        images = ImageList.from_tensors(images,
                                        self.backbone.size_divisibility)
        return images


class DecoupledSOLOHead(SOLOHead):
    """
    The head used in SOLO for instance segmentation.
    """

    def __init__(self, cfg, input_shape: List[ShapeSpec]):
        super().__init__(cfg, input_shape)

    def _init_layers(self):
        ins_convs_x = []
        ins_convs_y = []
        cate_convs = []
        for i in range(self.stacked_convs):
            # Mask branch
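            # the first conv takes one extra input channel: forward_single_level
            # concatenates a normalized coordinate map (x or y, CoordConv-style)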
            chn = self.in_channels + 1 if i == 0 else self.seg_feat_channels
            ins_convs_x.append(
                nn.Conv2d(chn,
                          self.seg_feat_channels,
                          kernel_size=3,
                          stride=1,
                          padding=1,
                          bias=not self.norm)
            )
            ins_convs_y.append(
                nn.Conv2d(chn,
                          self.seg_feat_channels,
                          kernel_size=3,
                          stride=1,
                          padding=1,
                          bias=not self.norm)
            )

            if self.norm:
                ins_convs_x.append(get_norm(self.norm, self.seg_feat_channels))
                ins_convs_y.append(get_norm(self.norm, self.seg_feat_channels))

            ins_convs_x.append(nn.ReLU(inplace=True))
            ins_convs_y.append(nn.ReLU(inplace=True))

            # Category branch
            chn = self.in_channels if i == 0 else self.seg_feat_channels
            cate_convs.append(
                nn.Conv2d(chn,
                          self.seg_feat_channels,
                          kernel_size=3,
                          stride=1,
                          padding=1,
                          bias=not self.norm)
            )
            if self.norm:
                cate_convs.append(get_norm(self.norm, self.seg_feat_channels))
            cate_convs.append(nn.ReLU(inplace=True))

        self.ins_convs_x = nn.Sequential(*ins_convs_x)
        self.ins_convs_y = nn.Sequential(*ins_convs_y)
        self.cate_convs = nn.Sequential(*cate_convs)

        self.solo_ins_list_x = nn.ModuleList()
        self.solo_ins_list_y = nn.ModuleList()
        for seg_num_grid in self.seg_num_grids:
            self.solo_ins_list_x.append(
                nn.Conv2d(
                    self.seg_feat_channels,
                    seg_num_grid,
                    kernel_size=3,
                    stride=1,
                    padding=1)
            )
            self.solo_ins_list_y.append(
                nn.Conv2d(
                    self.seg_feat_channels,
                    seg_num_grid,
                    kernel_size=3,
                    stride=1,
                    padding=1)
            )

        self.solo_cate = nn.Conv2d(
            self.seg_feat_channels,
            self.num_classes,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def _init_weights(self):
        for modules in [self.ins_convs_x, self.ins_convs_y, self.cate_convs]:
            for m in modules.modules():
                if isinstance(m, nn.Conv2d):
                    normal_init(m, std=0.01)
        # Use prior in model initialization to improve stability
        bias_value = -math.log((1 - self.prior_prob) / self.prior_prob)
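        # e.g. prior_prob = 0.01 gives bias ~= -4.6, so initial sigmoid scores
        # start near the rare-foreground prior (as in RetinaNet)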
        for modules in [self.solo_ins_list_x, self.solo_ins_list_y, self.solo_cate]:
            for m in modules.modules():
                if isinstance(m, nn.Conv2d):
                    normal_init(m, std=0.01, bias=bias_value)

    def forward(self, features, eval=False):
        new_features = self.split_features(features)
        featmap_sizes = [featmap.size()[-2:] for featmap in new_features]
        upsampled_size = (featmap_sizes[0][0] * 2, featmap_sizes[0][1] * 2)
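        # eval-time masks from every level are resized to twice the largest
        # feature map, matching the 2x upsampling inside each mask branch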
        ins_pred_x, ins_pred_y, cate_pred = multi_apply(
            self.forward_single_level,
            new_features,
            list(range(len(self.seg_num_grids))),
            eval=eval,
            upsampled_size=upsampled_size
        )
        return ins_pred_x, ins_pred_y, cate_pred

    def forward_single_level(self, x, idx, eval=False, upsampled_size=None):
        ins_feat = x
        cate_feat = x
        # Ins branch
        # concat a single normalized coordinate channel (CoordConv-style):
        # x for the x-branch, y for the y-branch
        x_range = torch.linspace(-1, 1, ins_feat.shape[-1], device=ins_feat.device)
        y_range = torch.linspace(-1, 1, ins_feat.shape[-2], device=ins_feat.device)
        y, x = torch.meshgrid(y_range, x_range)
        y = y.expand([ins_feat.shape[0], 1, -1, -1])
        x = x.expand([ins_feat.shape[0], 1, -1, -1])
        ins_feat_x = torch.cat([ins_feat, x], 1)
        ins_feat_y = torch.cat([ins_feat, y], 1)

        ins_feat_x = self.ins_convs_x(ins_feat_x)
        ins_feat_y = self.ins_convs_y(ins_feat_y)

        ins_feat_x = F.interpolate(ins_feat_x, scale_factor=2, mode='bilinear')
        ins_feat_y = F.interpolate(ins_feat_y, scale_factor=2, mode='bilinear')

        ins_pred_x = self.solo_ins_list_x[idx](ins_feat_x)
        ins_pred_y = self.solo_ins_list_y[idx](ins_feat_y)

        # Cate branch
        seg_num_grid = self.seg_num_grids[idx]
        # align category features to this level's S x S grid
        cate_feat = F.interpolate(cate_feat, size=seg_num_grid, mode='bilinear')
        cate_feat = self.cate_convs(cate_feat)
        cate_pred = self.solo_cate(cate_feat)

        if eval:
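            # return per-level probabilities at a common resolution and apply
            # points_nms (a max-pool local-maximum filter) to keep only peak
            # category responses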
            ins_pred_x = F.interpolate(ins_pred_x.sigmoid(), size=upsampled_size, mode='bilinear')
            ins_pred_y = F.interpolate(ins_pred_y.sigmoid(), size=upsampled_size, mode='bilinear')
            cate_pred = points_nms(cate_pred.sigmoid(), kernel=2).permute(0, 2, 3, 1)
        return ins_pred_x, ins_pred_y, cate_pred
